"""Get messages from DW and forward them to upstream KCIDB."""
import argparse
import contextlib
import os

from cki_lib import messagequeue
from cki_lib import metrics
from cki_lib import misc
from cki_lib.kcidb import ValidationError
from cki_lib.kcidb import validate_extended_kcidb_schema
from cki_lib.logger import get_logger
import kcidb
import prometheus_client as prometheus

LOGGER = get_logger(__name__)

METRIC_OBJECT_BUFFERED = prometheus.Histogram(
    'cki_kcidb_object_buffered', 'Number of KCIDB objects buffered for forwarding')
METRIC_OBJECT_FORWARDED = prometheus.Counter(
    'cki_kcidb_object_forwarded', 'Number of KCIDB objects forwarded')


class KCIDBMessage:
    # pylint: disable=too-many-instance-attributes
    """A KCIDB message."""

    def __init__(self):
        """Create an object."""
        self.kcidb_client = kcidb.Client(project_id=os.environ.get('KCIDB_PROJECT_ID'),
                                         topic_name=os.environ.get('KCIDB_TOPIC_NAME'))
        self.checkouts = []
        self.builds = []
        self.tests = []
        self.issues = []
        self.incidents = []
        self.version = None
        self.msg_ack_callbacks = []
        self.max_batch_size = misc.get_env_int('MAX_BATCH_SIZE', 500)

    def add(self, obj_type, obj, ack_fn):
        """Add the object to the correct list."""
        LOGGER.debug('Adding to KCIDBMessage: "%s" (%s)', obj_type, obj)
        obj_lists = {
            "checkout": self.checkouts,
            "build": self.builds,
            "test": self.tests,
            "issue": self.issues,
            "incident": self.incidents,
        }
        obj_lists[obj_type].append(obj)
        self.msg_ack_callbacks.append(ack_fn)

        METRIC_OBJECT_BUFFERED.observe(len(self))

        if not self.version:
            self.version = misc.get_nested_key(obj, 'misc/kcidb/version')

        for key in ('kcidb', 'is_public'):
            with contextlib.suppress(KeyError):
                del obj["misc"][key]

    def _sanitize_test(self):
        """Remove spaces from the test path."""
        for test in self.tests:
            path = test.get('path')
            if path:
                test['path'] = test['path'].replace(' ', '_')

    @property
    def encoded(self):
        """Return the KCIDB message object."""
        self._sanitize_test()
        data = {
            "version": self.version,
            "checkouts": self.checkouts,
            "builds": self.builds,
            "tests": self.tests,
            "issues": self.issues,
            "incidents": self.incidents,
        }

        # Validating against the CKI schema is not required, but brings consistency
        validate_extended_kcidb_schema(data, raise_for_cki=False)

        return data

    def clear(self):
        """Remove all elements."""
        LOGGER.debug('Clearing KCIDBMessage elements')
        self.checkouts = []
        self.builds = []
        self.tests = []
        self.issues = []
        self.incidents = []
        self.version = None
        self.msg_ack_callbacks = []

    def ack_messages(self):
        """Ack all messages contained on this object."""
        LOGGER.debug('Acking %i messages', len(self.msg_ack_callbacks))
        for ack_fn in self.msg_ack_callbacks:
            ack_fn()

    def submit(self):
        """Upload message to upstream kcidb."""
        LOGGER.debug('Submitting KCIDBMessage elements')
        if misc.is_production():
            try:
                self.kcidb_client.submit(self.encoded)
            except ValidationError:
                LOGGER.exception(
                    "Tried to submit invalid KCIDB data upstream. The message will be acked.")
            else:
                LOGGER.info('%i elements submitted', len(self))
                METRIC_OBJECT_FORWARDED.inc(amount=len(self))
        else:
            LOGGER.info('production mode would submit %i elements', len(self))
        self.ack_messages()
        self.clear()

    def __len__(self):
        """Return number of elements."""
        return (
            len(self.checkouts)
            + len(self.builds)
            + len(self.tests)
            + len(self.issues)
            + len(self.incidents)
        )

    def callback(self, body=None, ack_fn=None, **_):  # noqa: C901,PLR0911
        # pylint: disable=too-many-return-statements
        """Process one received message."""
        if not body:
            if len(self):
                LOGGER.debug('Calling submit')
                self.submit()
            return

        object_type = body['object_type']
        object_id = misc.get_nested_key(body, 'object/id')

        LOGGER.info('Processing message for %s %s', object_type, object_id)

        if (status := body["status"]) not in {"new", "updated"}:
            LOGGER.info('Message status "%s" is not "new" or "updated", skipping', status)
            ack_fn()
            return

        if misc.get_nested_key(body, 'object/origin') != 'redhat':
            LOGGER.info('Object not created by CKI, skipping')
            ack_fn()
            return

        # only publish public production pipelines
        if (
            not misc.get_nested_key(body, "object/misc/is_public")
            or misc.get_nested_key(body, "object/misc/retrigger")
        ):
            LOGGER.info('Object is not public, skipping')
            ack_fn()
            return

        if not misc.get_nested_key(body, "object/misc/kcidb/version"):
            LOGGER.info("Message is missing the KCIDB version, skipping")
            ack_fn()
            return

        if object_type == "checkout":
            if misc.get_nested_key(body, "object/misc/scratch"):
                LOGGER.info("Object is just a scratch, skipping")
                ack_fn()
                return

            if "kernel" not in misc.get_nested_key(
                body, "object/misc/source_package_name", "kernel"
            ):
                LOGGER.info("Message is not about the kernel, skipping")
                ack_fn()
                return

            if (
                not misc.get_nested_key(body, "object/git_repository_url")
                # Brew builds have a fake value as git_commit_hash
                and misc.get_nested_key(body, "object/git_commit_hash")
            ):
                # Remove it to avoid collisions upstream
                del body['object']['git_commit_hash']

        self.add(object_type, body['object'], ack_fn)

        if len(self) > self.max_batch_size:
            LOGGER.debug('Reached MAX_BATCH_SIZE, calling submit')
            self.submit()


def main(args: list[str] | None = None) -> None:
    """Run the main CLI Interface."""
    # No options are supported; parsing still rejects unexpected arguments.
    argparse.ArgumentParser().parse_args(args)

    metrics.prometheus_init()

    forwarder = KCIDBMessage()

    exchange = os.environ.get('WEBHOOK_RECEIVER_EXCHANGE', 'cki.exchange.webhooks')
    routing_keys = os.environ['DATAWAREHOUSE_KCIDB_FORWARDER_ROUTING_KEYS'].split()

    # Block consuming DW messages, handing each one to the forwarder buffer.
    messagequeue.MessageQueue().consume_messages(
        exchange,
        routing_keys,
        forwarder.callback,
        queue_name=os.environ.get('DATAWAREHOUSE_KCIDB_FORWARDER_QUEUE'),
        prefetch_count=0,
        inactivity_timeout=misc.get_env_int('RABBITMQ_TIMEOUT_S', 30),
        return_on_timeout=False,
        manual_ack=True)
